## Eos, Dawn of Light -- A Space Opera
## Copyright (c) 2007 Casey Duncan and contributors
## See LICENSE.txt for licensing details

# A.I.
# $Id: ai.py 327 2007-11-13 07:29:12Z cduncan $

import sys
import time
import math
from pygame.sprite import Group, GroupSingle
from pygame.locals import *
import ode
import game
import body
import vector
import sprite
from vector import diagonal, halfcircle, rightangle, fullcircle
from vessel import Control, Vessel, DirectionalThrusters
import staticbody

class BasicAI(Control):
	"""Basic AI Control

	Every update_interval millis the AI (re)acquires a target, selects a
	steering function and converts the heading/velocity that function
	returns (plus a vessel avoidance term) into turn and thrust controls.
	Weapon controls mirror each weapon's targeted flag on every update call.
	"""

	# Maximum distance to target when we have an objective
	target_max_distance_with_objective = 2000

	# Evasion thresholds: start evading when health falls below
	# evade_min_health and we took damage within the last
	# evade_damage_timeout millis; stop evading once health is above
	# evade_max_health or the target is farther than evade_max_distance.
	evade_min_health = 0.66
	evade_max_health = 0.75
	evade_damage_timeout = 2000
	evade_max_distance = 350

	update_interval = 100 # millis between updates
	
	def __init__(self, vessel, target=None, objective=None, sensor=None):
		"""Create an AI controlling vessel.

		vessel -- the vessel to steer; its collide handler is replaced by ours
		target -- optional initial target sprite (or a group holding one)
		objective -- optional long term objective sprite
		sensor -- optional Sensor used to find targets
		"""
		Control.__init__(self)
		self.vessel = vessel
		# override vessel collision handler with our own
		self.vessel.collide = self.collide
		self.target = GroupSingle()
		if target:
			if hasattr(target, 'sprite'):
				# unwrap a group holding the target
				target = target.sprite
			if target.alive():
				self.target.add(target)
		self.objective = GroupSingle()
		if objective:
			self.objective.add(objective)
		self.steerfunc = self.standoff
		# Vessels we collided with; steered away from until far enough away
		self.close_vessels = Group()
		self.proximity_radius = self.vessel.collision_radius * 5
		self.sensor = sensor
		self.target_time = 0
		self.next_update = 0
	
	def seek_position(self, target_position, predict_ahead=2.0):
		"""Return desired velocity towards a fixed position

		Our own position is extrapolated predict_ahead seconds along our
		current velocity before computing the direction.

		>>> game.init()
		>>> ai = BasicAI(Vessel(), None)
		>>> ai.vessel.max_speed = 10
		>>> ai.vessel.position = (100, 0)
		>>> vel = ai.seek_position((-100, 0))
		>>> vel.x
		-10.0
		>>> vel.y
		0.0
		"""
		position = self.vessel.position + self.vessel.velocity * predict_ahead
		return vector.normal(target_position - position) * self.vessel.max_speed

	def seek(self, target=None):
		"""Return desired heading and velocity towards a fixed target,
		which defaults to our current target

		>>> game.init()
		>>> target = Vessel()
		>>> ai = BasicAI(Vessel(), target)
		>>> ai.vessel.max_speed = 33.0
		>>> ai.vessel.position = Vector2D(0, 0)
		>>> target.position = Vector2D(0, 400)
		>>> head, vel = ai.seek()
		>>> vel.x
		0.0
		>>> vel.y
		33.0
		>>> head == rightangle
		True
		"""
		if target is None:
			target_position = self.target.sprite.position
		else:
			target_position = target.position
		velocity = self.seek_position(target_position)
		return vector.radians(velocity), velocity

	def flee(self):
		"""Return desired heading and velocity away from our target

		>>> game.init()
		>>> target = Vessel()
		>>> ai = BasicAI(Vessel(), target)
		>>> ai.vessel.max_speed = 33.0
		>>> ai.vessel.position = Vector2D(0, 0)
		>>> target.position = Vector2D(0, 400)
		>>> head, vel = ai.flee()
		>>> vel.x
		-0.0
		>>> vel.y
		-33.0
		>>> head == rightangle + halfcircle
		True
		"""
		heading, velocity = self.seek()
		return (heading + halfcircle) % fullcircle, -velocity

	def predict_intercept(self, target, target_predict=0.75, self_predict=0.75):
		"""Return an approach vector to the target

		The target position is extrapolated along its velocity by the
		estimated time to reach it (distance / max_speed, capped at 1.0).
		The returned vector points there from our own position extrapolated
		self_predict seconds ahead.

		target_predict -- seconds to extrapolate the target's position when
		estimating the time to reach it.
		self_predict -- seconds to extrapolate our own position.
		Low values encourage "undershoot" behavior, high values "overshoot".

		>>> game.init()
		>>> target = Vessel()
		>>> ai = BasicAI(Vessel(), target)
		>>> ai.vessel.max_speed = 100
		>>> ai.vessel.position = Vector2D(0, 0)
		>>> target.position = Vector2D(150, 0)
		>>> target.velocity = Vector2D(20, 0)
		>>> pos = ai.predict_intercept(target)
		>>> int(pos.x)
		180
		>>> int(pos.y)
		0
		"""
		velocity = self.vessel.velocity
		position = self.vessel.position
		predict_ahead = vector.distance(position + velocity * self_predict,
			target.position + target.velocity * target_predict) / self.vessel.max_speed
		return (target.position + target.velocity * min(predict_ahead, 1.0)) - (
			position + velocity * self_predict)

	def pursue(self, predict_ahead=0.75):
		"""Return desired heading and velocity towards where we predict
		our target to be

		predict_ahead -- currently unused; the intercept is always computed
		with self_predict=0.4.

		>>> game.init()
		>>> target = Vessel()
		>>> ai = BasicAI(Vessel(), target)
		>>> ai.vessel.max_speed = 100
		>>> ai.vessel.position = Vector2D(0, 0)
		>>> target.position = Vector2D(150, 0)
		>>> target.velocity = Vector2D(-200, 0)
		>>> head, vel = ai.pursue()
		>>> vel.x
		-100.0
		>>> vel.y
		0.0
		>>> head
		0.0
		"""
		target = self.target.sprite
		position = self.vessel.position
		if not target.velocity and vector.distance(
			position, target.position) < target.collision_radius * 1.5:
			# close in to stationary target
			return vector.radians(target.position - position), target.velocity
		approach = self.predict_intercept(target, self_predict=0.4)
		return vector.radians(approach), vector.clamp(approach * game.fps, self.vessel.max_speed)
	
	def standoff(self):
		"""Pursue a target, keeping our distance

		Returns (heading, velocity). Far from the target we close in at up
		to max speed; inside the desired distance we back off along the
		approach vector.
		"""
		target = self.target.sprite
		position = self.vessel.position
		to_target = target.position - position
		approach = self.predict_intercept(target)
		distance = vector.length(target.position - position)
		if not target.velocity:
			# stationary target
			heading = vector.radians(to_target)
			desired_dist = target.collision_radius + self.vessel.collision_radius * 5
			if distance < desired_dist:
				return self.vessel.heading, vector.vector2()
		elif self.vessel.is_friendly(target):
			# fly in formation with friends
			heading = target.heading
			desired_dist = self.vessel.collision_radius * 4 + target.collision_radius * 3
		else:
			desired_dist = (self.vessel.collision_radius + target.collision_radius + 
				self.vessel.standoff_distance)
			# blend facing the target with facing the intercept, favoring
			# the intercept as distance grows
			heading = (vector.radians(to_target) * desired_dist * 4
				+ vector.radians(approach) * distance) / (desired_dist * 4 + distance)
		if distance > desired_dist:
			# Far from target, catch up as fast as we can
			return vector.radians(approach), vector.clamp(
				approach * game.fps, self.vessel.max_speed)
		else:
			# close to target, keep a distance
			return heading, vector.clamp(-(approach * 
				(desired_dist - vector.length(approach))), self.vessel.max_speed)

	def evade(self):
		"""Return desired velocity away from where we predict
		our target to be. Under evasion, we still turn toward
		the target, under the assumption that we will resume
		a pursuit or seek afterward.

		>>> game.init()
		>>> target = Vessel()
		>>> ai = BasicAI(Vessel(), target)
		>>> ai.vessel.max_speed = 42
		>>> ai.vessel.position = Vector2D(0, 0)
		>>> target.position = Vector2D(0, 150)
		>>> target.velocity = Vector2D(0, -200)
		>>> head, vel = ai.evade()
		>>> vel.x
		-0.0
		>>> vel.y
		42.0
		>>> head == rightangle
		True
		"""
		heading, seek_velocity = self.pursue()
		velocity = vector.unit(
			vector.radians(seek_velocity) + rightangle * 1.5) * self.vessel.max_speed
		return heading, velocity

	def avoid_vessels(self):
		"""Return a vector away from other nearby vessels to avoid stacking up

		Vessels that drift beyond 5 proximity radii are dropped from
		close_vessels.
		"""
		avoid_vec = vector.vector2()
		# Compute the center vector amongst close vessels and avoid it
		# (iterate a copy since we may remove vessels while looping)
		for vessel in list(self.close_vessels):
			proximity = max(vector.distance(
				self.vessel.position, vessel.position) - vessel.collision_radius, 0)
			if proximity < self.proximity_radius * 4:
				# push away harder from heavier vessels
				away = ((self.vessel.position + self.vessel.velocity / 5) - 
					(vessel.position + vessel.velocity / 5) - 
					(vessel.collision_radius + self.vessel.collision_radius)) * (
					vessel.mass / self.vessel.mass)
				avoid_vec += away / (vector.length(away) / self.proximity_radius or 0.001)
			elif proximity > self.proximity_radius * 5:
				# Other vessel is not considered "close" anymore
				self.close_vessels.remove(vessel)
		return avoid_vec

	def steer(self, desired_heading, desired_velocity):
		"""Set the turn and thrust controls to approach the desired heading
		and velocity.
		"""
		heading_diff = desired_heading - self.vessel.heading
		if heading_diff > halfcircle:
			heading_diff -= fullcircle
		elif heading_diff < -halfcircle:
			heading_diff += fullcircle

		turn_rate = self.vessel.turn_rate
		max_turn_rate = self.vessel.directional_thrusters.max_turn_rate
		# Ignore heading errors below this threshold (no turning)
		turn_threshold = turn_rate * (self.update_interval / 300.0)
		if heading_diff > turn_threshold or heading_diff < -turn_threshold:
			self.turn = (heading_diff / max_turn_rate)  * 1000 / self.update_interval
			self.turn = max(min(self.turn, 1), -1)
		else:
			self.turn = 0
		if self.vessel.velocity != desired_velocity:
			# Direction of the needed velocity change relative to our heading,
			# wrapped into [0, fullcircle) so the sector tests below cover
			# every direction (negative angles would otherwise always fall
			# into the forward thrust sector)
			thrust_dir = (vector.radians(
				desired_velocity - self.vessel.velocity) - self.vessel.heading) % fullcircle
			self.thrust = thrust_dir < diagonal or thrust_dir > fullcircle - diagonal
			self.fw_maneuver = not self.thrust and (
				thrust_dir < rightangle or thrust_dir > fullcircle - rightangle)
			self.right_maneuver = halfcircle > thrust_dir + diagonal > rightangle
			self.bw_maneuver = halfcircle + rightangle > thrust_dir + diagonal > halfcircle
			self.left_maneuver = fullcircle > thrust_dir + diagonal > halfcircle + rightangle
		else:
			self.thrust = False
			self.fw_maneuver = False
			self.right_maneuver = False
			self.left_maneuver = False
			self.bw_maneuver = False
	
	def acquire_target(self):
		"""auto-select the appropriate target"""
		if (self.target and self.objective and vector.distance(
			self.vessel.position, self.target.sprite.position) >
			self.target_max_distance_with_objective
			and self.target not in self.objective):
			# Too far from a non-objective target, drop it
			self.target.empty()
		if (not self.target or game.time > self.target_time
			or self.target.sprite.explosion is not None):
			# acquire a target vessel
			if self.sensor is not None:
				# Use the sensor to find a target
				if (self.sensor.closest_vessel and (not self.objective or 
					vector.distance(self.vessel.position, 
						self.sensor.closest_vessel.sprite.position) < 
					self.target_max_distance_with_objective)):
					self.set_target(self.sensor.closest_vessel)
					self.sensor.disable()
				elif not self.sensor.enabled:
					self.sensor.enable()
			if not self.target:
				if not self.objective:
					self.choose_objective()
				# head for the objective looking for other targets
				self.set_target(self.objective, timeout=0)
	
	def choose_objective(self):
		"""Set the ship's objective, which by default is to head to the nearest
		friendly or neutral planet
		"""
		def distance(planet):
			return vector.distance(self.vessel.position, planet.position)
		planets = sorted(game.map.planets, key=distance)
		# Try for the closest friendly planet first
		for planet in planets:
			if planet.base is not None and self.vessel.is_friendly(planet.base.owner):
				self.objective.add(planet)
				return
		# Try for the closest neutral planet next
		for planet in planets:
			if planet.base is None:
				self.objective.add(planet)
				return
		# Just head for the closest planet at all
		self.objective.add(planets[0])
	
	def set_target(self, target, timeout=3000):
		"""Set the target for the ai, timing out after timeout milliseconds
		(game time) at which time it will acquire another target. A timeout
		of None never times out.
		"""
		self.target.add(target)
		if timeout is not None:
			self.target_time = game.time + timeout
		else:
			self.target_time = sys.maxint
	
	def select_steerfunc(self):
		"""Select the appropriate steering function

		Switch to evade when badly and recently damaged by a non-friendly
		target. Once evading, keep evading until health rises above
		evade_max_health or the target is farther than evade_max_distance.
		"""
		if (self.vessel.health < self.evade_min_health 
			and game.time - self.vessel.damage_time < self.evade_damage_timeout
			and not self.vessel.is_friendly(self.target.sprite)):
			self.steerfunc = self.evade
		elif self.steerfunc != self.standoff and (
			self.vessel.health > self.evade_max_health
			or vector.distance(self.target.sprite.position, self.vessel.position) 
				> self.evade_max_distance):
			# Bound methods compare equal (not identical) when they wrap the
			# same function and instance, hence != rather than "is not"
			self.steerfunc = self.standoff
		# Otherwise keep the current steering function

	def update(self):
		"""Periodically re-plan steering; always refresh weapon controls"""
		if game.time > self.next_update:
			self.next_update = game.time + self.update_interval
			self.acquire_target()
			self.select_steerfunc()
			# steering
			desired_heading, desired_velocity = self.steerfunc()
			desired_velocity += self.avoid_vessels()
			self.steer(desired_heading, desired_velocity)
		# Fire all targeted weapons
		for i, weapon in enumerate(self.vessel.weapons):
			self.weapons[i] = weapon.targeted
	
	def collide(self, other, contacts):
		"""Detect contact with other vessels to avoid stacking

		>>> game.init()
		>>> ai = BasicAI(Vessel(), None)
		>>> len(ai.close_vessels)
		0
		>>> v = Vessel()
		>>> ai.collide(v, [])
		>>> len(ai.close_vessels)
		1
		>>> ai.collide(object(), [])
		>>> len(ai.close_vessels)
		1
		"""
		if isinstance(other, Vessel):
			# Keep track of other vessels we are stacked on
			self.close_vessels.add(other)


class EvaderAI(BasicAI):
	"""AI that evades by veering off at a right angle to its pursuit course"""

	evade_max_health = 0.7
	evade_min_health = 0.5

	def evade(self):
		"""Turn away from the target

		The escape direction is the pursuit velocity direction rotated by a
		right angle; speed is ten times our maximum speed.
		"""
		pursue_velocity = self.pursue()[1]
		escape_angle = vector.radians(pursue_velocity) + rightangle
		escape_velocity = vector.unit(escape_angle) * self.vessel.max_speed * 10
		return vector.radians(escape_velocity), escape_velocity


class AggroAI(BasicAI):
	"""AI that keeps fighting at lower health and evades by charging"""

	evade_min_health = 0.4
	evade_max_health = 0.5

	def evade(self):
		"""Head through the target

		Steer opposite to the target's heading at ten times our maximum
		speed. The heading is wrapped into [0, fullcircle), like the
		heading returned by flee().
		"""
		heading = (self.target.sprite.heading + halfcircle) % fullcircle
		return heading, vector.unit(heading) * self.vessel.max_speed * 10


class Standoffish(BasicAI):
	"""AI that always uses standoff steering, with the resulting velocity
	multiplied by the orbit vector
	"""
	
	orbit = vector.vector2(1, -0.25)

	def standoff(self):
		"""Base class standoff with the velocity multiplied by orbit"""
		result = BasicAI.standoff(self)
		return result[0], result[1] * self.orbit
	
	def select_steerfunc(self):
		"""Never switch away from standoff steering"""
		self.steerfunc = self.standoff


class HitAndRun(BasicAI):
	"""AI that flees straight away from its target when evading"""

	# Evasion tuning (compare with the BasicAI defaults)
	evade_damage_timeout = 1200
	evade_max_distance = 250
	evade_min_health = 0.7
	evade_max_health = 0.75

	# Reuse the base class flee behavior as our evasion
	evade = BasicAI.flee


class AssaultAI(BasicAI):
	"""AI that heads straight for a planet objective, otherwise pursuing
	its target and fleeing once health drops to 0.75 or below
	"""

	def select_steerfunc(self):
		"""Seek planets; pursue friends or while healthy; otherwise flee"""
		target = self.target.sprite
		if isinstance(target, staticbody.Planet):
			chosen = self.seek
		elif not (self.vessel.is_friendly(target) or self.vessel.health > .75):
			chosen = self.flee
		else:
			chosen = self.pursue
		self.steerfunc = chosen
	
	def acquire_target(self):
		"""Always target a planet objective; otherwise use the default
		target acquisition
		"""
		if not isinstance(self.objective.sprite, staticbody.Planet):
			BasicAI.acquire_target(self)
			return
		# When targetting a planet, always go there
		self.target.add(self.objective)


class GnatAI(BasicAI):
	"""Gnat AI control

	Gnats attack whatever their objective is targeting and return to the
	objective (presumably a carrier with a fighter bay) to dock when low
	on energy or health, or once their hostile target is gone.
	"""

	is_returning = False # Is returning to the objective
	had_target = False # Has had a hostile target since last docking

	def steer(self, desired_heading, desired_velocity):
		"""Steer like BasicAI, but using maneuvering thrusters only"""
		BasicAI.steer(self, desired_heading, desired_velocity)
		# gnats only have maneuvering thrusters
		self.fw_maneuver = self.fw_maneuver or self.thrust
		self.thrust = False

	def acquire_target(self):
		"""Pick a target: return home when needed, otherwise share the
		objective's target, falling back to the objective itself
		"""
		if self.target and not self.vessel.is_friendly(self.target.sprite):
			self.had_target = True
		elif self.had_target:
			# Target was destroyed
			self.is_returning = True
		if self.is_returning or (self.vessel.energy < (self.vessel.max_energy * 0.1)
			or self.vessel.health < 0.25):
			# Return to the objective if our energy is expended
			self.target.add(self.objective)
			self.is_returning = True
		elif (not self.target or self.vessel.is_friendly(self.target.sprite)) and self.objective:
			# Target whatever the objective is targeting
			self.target.add(self.objective.sprite.control.target)
		if not self.target:
			if self.objective:
				# head for the objective
				self.target.add(self.objective)
			else:
				# Commit suicide if our objective dies
				self.vessel.explode('hit.wav')
				self.target.add(game.map.planets)
	
	def select_steerfunc(self):
		"""Orbit targets within 350 units, pursue anything farther"""
		if (self.target 
			and vector.distance(self.target.sprite.position, self.vessel.position) < 350):
			# At close range, circle the target
			self.steerfunc = self.orbit
		else:
			self.steerfunc = self.pursue

	def orbit(self):
		"""Steer in an orbit around our target

		Keeps the pursuit heading but moves at max speed in a direction
		rotated by a diagonal from the pursuit velocity.
		"""
		heading, seek_velocity = self.pursue(predict_ahead=2.0)
		# Veer off the direct pursuit course so we circle the target
		# instead of flying straight at it
		velocity = vector.unit(
			vector.radians(seek_velocity) + diagonal) * self.vessel.max_speed
		return heading, velocity
	
	def collide(self, other, contacts):
		"""Dock with the objective when returning; otherwise track stacking"""
		if self.is_returning and other in self.objective:
			if hasattr(self.objective.sprite, 'fighter_bay'):
				self.objective.sprite.fighter_bay.dock(self.vessel)
				self.is_returning = False
				self.had_target = False
				self.target.empty()
		else:
			BasicAI.collide(self, other, contacts)

def control(vessel, target=None, objective=None, sensor=None):
	"""Return the appropriate ai control instance for the specified vessel

	The control class is looked up by vessel.ai_name in this module. When
	no sensor is supplied and the vessel has a collision category, a
	disabled sensor of radius 10000 detecting every other category is
	created for it.
	"""
	make_sensor = sensor is None and vessel.category is not None
	if make_sensor:
		other_bits = body.everything & ~vessel.category
		sensor = Sensor(vessel, 10000, other_bits)
		sensor.disable()
	personality = globals()[vessel.ai_name]
	return personality(vessel, target, objective, sensor)

class AIVessel(Vessel):
	"""Vessel under ai control"""

	def __init__(self, target=None, category=None, objective=None, sensor=None,
		ai='BasicAI', **kw):
		"""Create the vessel and attach an ai control to it.

		target -- Sprite to intercept and attack if foe.
		category -- Vessel category, determines friends and foes.
		objective -- Long term objective sprite or base
		sensor -- Sensor object to be used (if omitted one is created just for this vessel)
		ai -- class name of ai personality.
		"""
		Vessel.__init__(self, ai=ai, **kw)
		if category is not None:
			# collide with everything except shots
			collide_bits = body.everything & ~body.shot
			self.setup_collision(category, collide_bits)
		self.control = control(self, target, objective, sensor)
		# Make sure we are behind the local player
		self.layer.to_back(self)
	


class Sensor:
	"""Detects bodies in the vicinity of a host ship

	Bodies overlapping the sensor sphere are reported through collide().
	The first detection of each game frame starts a new sweep; detected
	returns the current sweep's bodies ordered by distance, and
	closest_vessel holds the vessel selected as closest.
	"""

	enabled = True

	def __init__(self, vessel, radius, detect_bits=body.everything, exclude=()):
		"""Create a sensor which remains centered around the specified 
		vessel and detects bodys with categories intersecting detect_bits
		within the radius specified. Bodies in exclude (a sequence of bodies
		such as a sprite group) are never detected.

		>>> game.init()
		>>> v = Vessel()
		>>> s = Sensor(v, 100, 1)
		>>> s.vessel is v
		True
		>>> s.radius
		100
		>>> s.detect_bits
		1
		"""
		self.vessel = vessel
		self.radius = radius
		self.exclude = exclude
		self.geom = ode.GeomSphere(game.collision_space, self.radius)
		self.geom.setBody(vessel.body) # Move with the vessel
		self.geom.parent = self # attach ourself for callback purposes
		self.geom.setCategoryBits(body.nothing) # nothing collides with us
		self.detect_bits = detect_bits
		self.geom.setCollideBits(detect_bits)
		self.last_sweep = None
		self._detected = [] # (distance, body) pairs for the current sweep
		self.closest_vessel = GroupSingle()
		self._closest_dist = sys.maxint
		# Must exist before the first collide() call, which may read it
		self._closest_incidental = False
		self._closest_type = None # not read anywhere in this module
	
	def disable(self):
		"""Turn off sensor and forget the current sweep
		
		>>> game.init()
		>>> v = Vessel()
		>>> s = Sensor(v, 100, 1)
		>>> s.enabled
		True
		>>> s.disable()
		>>> s.enabled
		False
		"""
		self.geom.disable()
		self._detected = []
		self.closest_vessel.empty()
		# Reset closest tracking too, otherwise a stale distance would keep
		# farther vessels from becoming closest after re-enabling
		self._closest_dist = sys.maxint
		self._closest_incidental = False
		self.last_sweep = None
		self.enabled = False
	
	def enable(self):
		"""Turn on sensor

		>>> game.init()
		>>> v = Vessel()
		>>> s = Sensor(v, 100, 1)
		>>> s.enabled
		True
		>>> s.disable()
		>>> s.enabled
		False
		>>> s.enable()
		>>> s.enabled
		True
		"""
		self.geom.enable()
		self.enabled = True
				
	def collide(self, other, contacts):
		"""Detect a body in our radius
		
		>>> game.init()
		>>> v = Vessel()
		>>> s = Sensor(v, 100, 1)
		>>> o = Vessel()
		>>> s.collide(o, [])
		>>> s.detected == [o]
		True
		>>> s.closest_vessel.sprite is o
		True
		"""
		if other in self.exclude or other is self.vessel:
			return
		if self.last_sweep != game.frame_no:
			# First detection in a new frame: start a new sensor sweep.
			# (This must not depend on the previous sweep having found
			# anything, or a second detection in the same frame would wipe
			# out the first.)
			self._detected = []
			self._closest_dist = sys.maxint
			self._closest_incidental = False
			self.closest_vessel.empty()
			self.last_sweep = game.frame_no
		distance = vector.distance(self.vessel.position, other.position)
		self._detected.append((distance, other))
		if (isinstance(other, Vessel) and distance < self._closest_dist 
			or self._closest_incidental and other.incidental):
			self.closest_vessel.sprite = other
			self._closest_dist = distance
			self._closest_incidental = other.incidental
	
	@property
	def detected(self):
		"""Return a list of detected bodies in order by distance from vessel"""
		# Sort on distance only; bodies themselves need not be comparable
		self._detected.sort(key=lambda pair: pair[0])
		return [found for dist, found in self._detected]

	def setDetect(self, detect_bits):
		"""Change the category bits this sensor detects"""
		self.detect_bits = detect_bits
		self.geom.setCollideBits(detect_bits)
	
	def setRadius(self, radius):
		"""Change the detection radius"""
		self.radius = radius
		self.geom.setRadius(radius)


class SharedSensor(Sensor):
	"""Sensor shared between multiple AIs

	enable() and disable() do nothing, presumably so that one AI cannot
	switch the sensor on or off for the others sharing it.
	"""

	def enable(self):
		"""Ignored for shared sensors"""

	def disable(self):
		"""Ignored for shared sensors"""


class Tracker:
	"""Object that tracks the course of a ship and can predict its
	future location
	"""

	def __init__(self, target, sample_time=.250, max_samples=40):
		"""Track the target vessel, sampling its location every sample_time
		seconds, keeping at most max_samples.
		"""
		self.target = GroupSingle(target)
		self.samples = []
		self.next_sample = 0
		self.max_samples = max_samples
		self.sample_time = sample_time
		self.update()
	
	def update(self):
		"""Take a sample of the location of the target if the time since the
		last sample is over the sample time.  Otherwise do nothing. 
		Return True if a sample was actually taken, False if not.
		"""
		if game.time > self.next_sample and self.target:
			# game.time is in millis, sample_time in seconds
			self.next_sample = game.time + self.sample_time * 1000
			self.samples.append(vector.to_tuple(self.target.sprite.position))
			while len(self.samples) > self.max_samples:
				self.samples.pop(0)
			return True
		else:
			return False

	def _linear_correlation(self):
		"""Return the absolute Pearson correlation of the sampled x and y
		coordinates: 1.0 when the samples lie on a straight line, near 0.0
		when they are scattered.
		"""
		count = float(len(self.samples))
		mean_x = sum(x for x, y in self.samples) / count
		mean_y = sum(y for x, y in self.samples) / count
		sxy = sxx = syy = 0.0
		for x, y in self.samples:
			dx = x - mean_x
			dy = y - mean_y
			sxy += dx * dy
			sxx += dx * dx
			syy += dy * dy
		total = sxx + syy
		if total == 0.0 or min(sxx, syy) <= 1e-12 * total:
			# All samples lie on a horizontal or vertical line (or on a
			# single point), which is perfectly linear but makes the
			# correlation formula divide by zero
			return 1.0
		return abs(sxy / math.sqrt(sxx * syy))
	
	def predict(self, time_ahead):
		"""Return the predicted location for the target at time_ahead
		seconds in the future. If the target no longer exists,
		return None.

		Extrapolates along the target's current velocity and along its
		average velocity over the sampled course, blending the two
		according to how linear the sampled course is.
		"""
		if not self.target:
			return None
		target = self.target.sprite
		v1 = target.velocity * min(time_ahead, self.sample_time * 3)
		sample_count = len(self.samples)
		if sample_count <= 3:
			# with very few samples, just extrapolate ahead given
			# the target's current heading and velocity, limiting
			# time_ahead to a small value to minimize overshoots
			return target.position + v1
		startx, starty = self.samples[0]
		endx, endy = self.samples[-1]
		# Average velocity over the sampled course; n samples span
		# n - 1 sampling intervals
		v2 = vector.vector2(endx - startx, endy - starty) / (
			(sample_count - 1) * self.sample_time) * time_ahead
		correlation = self._linear_correlation()
		if correlation > 0.5:
			# We have a linear path, interpolate between extrapolating 
			# along this path and extrapolating on the current velocity
			return target.position + (v2 * correlation) + (v1 * (1 - correlation))
		else:
			# The path is very non-linear, average both extrapolated
			# positions with the midpoint of the first and last samples
			# to hedge against a future non-linear trajectory
			hedge = vector.vector2((startx + endx) / 2.0, (starty + endy) / 2.0)
			return ((target.position + v1) + (target.position + v2) + hedge) / 3


if __name__ == '__main__':
	"""Run tests if executed directly"""
	import sys, doctest
	failed, count = doctest.testmod()
	print 'Ran', count, 'test cases with', failed, 'failures'
	sys.exit(failed)
